import hashlib
import os
from io import BytesIO
from urllib.parse import urlparse

import numpy as np
import requests
import safetensors.torch
import torch

import folder_paths

class LoadLatentNumpy:
	"""Load a latent from a file in ComfyUI's input directory.

	The format is chosen by file extension:
	  * ``.latent`` / ``.safetensors`` - written by ComfyUI's default
	    SaveLatent node (tensor stored under ``"latent_tensor"``).
	  * ``.npy`` - a plain NumPy array saved as-is.
	  * ``.npz`` - latent caches generated by kohya sd-scripts.

	Returns a ComfyUI ``LATENT`` dict whose ``"samples"`` tensor is float32;
	3-D tensors get a leading dimension of size 1 prepended.
	"""

	def __init__(self):
		pass

	@classmethod
	def INPUT_TYPES(s):
		"""List the files in the input directory with a supported extension."""
		exts = (".latent", ".safetensors", ".npy", ".npz")
		input_dir = folder_paths.get_input_directory()
		files = [
			f for f in os.listdir(input_dir)
			if f.endswith(exts) and os.path.isfile(os.path.join(input_dir, f))
		]
		return {
			"required": {
				"latent": [sorted(files), ]
			},
		}

	RETURN_TYPES = ("LATENT",)
	FUNCTION = "load"
	CATEGORY = "remote/latent"
	TITLE = "Load Latent (Numpy)"

	def load_comfy(self, file):
		"""Load a latent saved by ComfyUI's default node (renamed safetensors).

		Args:
			file: a filesystem path, the raw file contents as ``bytes``, or a
				binary file-like object (its remaining content is read).

		Returns:
			The ``"latent_tensor"`` entry as a float32 tensor.
		"""
		if isinstance(file, (str, os.PathLike)):
			data = safetensors.torch.load_file(file)
		elif isinstance(file, (bytes, bytearray)):
			data = safetensors.torch.load(bytes(file))
		else:
			# safetensors.torch.load() takes bytes, not a stream.
			data = safetensors.torch.load(file.read())

		latent = data["latent_tensor"].to(torch.float32)
		if "latent_format_version_0" not in data:
			# Files lacking the version marker are rescaled by 1/0.18215,
			# mirroring the default LoadLatent node this loader copies.
			latent = latent * (1.0 / 0.18215)
		return latent

	def load_numpy(self, file):
		"""Load a plain ``.npy`` array (path or binary file-like) as a tensor."""
		# allow_pickle=False is NumPy's default; it is spelled out because input
		# may be untrusted (e.g. downloaded), and object arrays must be refused.
		return torch.from_numpy(np.load(file, allow_pickle=False))

	def load_koyha(self, file):
		"""Load a latent from a kohya sd-scripts ``.npz`` archive.

		Uses the ``"latents"`` entry when present; otherwise the first stored
		array with at least three dimensions.

		Args:
			file: a filesystem path or binary file-like object.

		Raises:
			ValueError: if the archive holds no array with ``ndim >= 3``.
		"""
		# NpzFile is a context manager; closing it releases the underlying file.
		with np.load(file, allow_pickle=False) as data:
			if "latents" in data.files:
				latent = data["latents"]
			else:
				arrays = (data[key] for key in data.files)
				latent = next((arr for arr in arrays if arr.ndim >= 3), None)
		if latent is None:
			raise ValueError("No latent array (ndim >= 3) found in npz file")
		return torch.from_numpy(latent)

	def load(self, latent):
		"""Load the selected input file and wrap it as a ComfyUI LATENT dict.

		Raises:
			ValueError: for an unrecognised extension that is not a loadable
				``.npy`` payload either.
		"""
		path = folder_paths.get_annotated_filepath(latent)
		# Use the resolved path so an annotation suffix on the widget value
		# cannot hide the real extension.
		ext = os.path.splitext(path)[1]

		if ext in (".latent", ".safetensors"):
			tensor = self.load_comfy(path)
		elif ext == ".npy":
			tensor = self.load_numpy(path)
		elif ext == ".npz":
			tensor = self.load_koyha(path)
		else:
			try:
				tensor = self.load_numpy(path)
			except Exception as e:
				raise ValueError(f"Unknown latent extension '{ext}'") from e

		# 3-D latents (presumably one unbatched sample) get a batch dim.
		if tensor.ndim == 3:
			tensor = tensor.unsqueeze(0)

		return ({"samples": tensor.to(torch.float32)},)

	@classmethod
	def IS_CHANGED(s, latent):
		"""Return the file's SHA-256 hex digest so content changes are detected."""
		path = folder_paths.get_annotated_filepath(latent)
		digest = hashlib.sha256()
		with open(path, "rb") as f:
			# Hash in 1 MiB chunks instead of reading the whole file at once.
			for chunk in iter(lambda: f.read(1 << 20), b""):
				digest.update(chunk)
		return digest.hexdigest()

	@classmethod
	def VALIDATE_INPUTS(s, latent):
		"""Reject names that do not resolve to an existing file."""
		if not folder_paths.exists_annotated_filepath(latent):
			return f"Invalid latent file '{latent}'"
		return True

class LoadLatentUrl(LoadLatentNumpy):
	"""Download a latent file over HTTP(S) and load it as a ComfyUI LATENT.

	The format is taken from the extension of the URL *path* (``.latent`` /
	``.safetensors`` / ``.npy`` / ``.npz``). If the path has no recognised
	extension, the payload is sniffed for the ``.npy`` magic or a zip
	(``.npz``) header, and otherwise safetensors is attempted.
	"""
	def __init__(self):
		pass

	@classmethod
	def INPUT_TYPES(s):
		return {
			"required": {
				"url": ("STRING", { "multiline": False, })
			}
		}

	RETURN_TYPES = ("LATENT",)
	TITLE = "Load Latent (URL)"

	@staticmethod
	def _guess_format(url, payload):
		"""Return "comfy", "npy", "npz", or None for a downloaded payload."""
		# Only the path is inspected, so a query string or directory name that
		# merely contains ".npy"/".latent" cannot mis-route the payload.
		ext = os.path.splitext(urlparse(url).path)[1].lower()
		if ext in (".latent", ".safetensors"):
			return "comfy"
		if ext in (".npy", ".npz"):
			return ext[1:]
		if payload.startswith(b"\x93NUMPY"):
			return "npy"
		if payload.startswith(b"PK\x03\x04"):
			return "npz"
		return None

	def load(self, url):
		"""Fetch ``url`` and return ``({"samples": float32 tensor},)``.

		Raises:
			requests.HTTPError: on a non-2xx response.
			ValueError: if the format is unknown and the payload is not a
				loadable safetensors file.
		"""
		# The whole body is needed in memory anyway, so no streaming.
		with requests.get(url, timeout=16) as r:
			r.raise_for_status()
			payload = r.content

		fmt = self._guess_format(url, payload)
		if fmt == "comfy":
			# safetensors.torch.load() takes raw bytes, not a file object.
			latent = self.load_comfy(payload)
		elif fmt == "npy":
			latent = self.load_numpy(BytesIO(payload))
		elif fmt == "npz":
			latent = self.load_koyha(BytesIO(payload))
		else:
			try:
				latent = self.load_comfy(payload)
			except Exception as e:
				raise ValueError(f"Unknown latent extension '{url}'") from e

		# 3-D latents (presumably one unbatched sample) get a batch dim.
		if latent.ndim == 3:
			latent = latent.unsqueeze(0)

		return ({"samples": latent.to(torch.float32)},)

	@classmethod
	def IS_CHANGED(s, url):
		"""Key caching on the URL string only; remote content is not re-hashed."""
		return str(url)

	@classmethod
	def VALIDATE_INPUTS(s, url):
		return True

class SaveLatentNumpy:
	"""Save a LATENT's ``"samples"`` tensor to the output directory as ``.npy``."""

	def __init__(self):
		self.output_dir = folder_paths.get_output_directory()

	@classmethod
	def INPUT_TYPES(s):
		return {
			"required": {
				"samples": ("LATENT",),
				"filename_prefix": ("STRING", {"default": "latents/ComfyUI"})
			}
		}

	RETURN_TYPES = ("STRING",)
	RETURN_NAMES = ("filename",)
	OUTPUT_NODE = True
	FUNCTION = "save"
	CATEGORY = "remote/latent"
	TITLE = "Save Latent (Numpy)"

	@staticmethod
	def _to_numpy(tensor):
		"""Convert ``tensor`` into a NumPy array that ``np.save`` can write.

		``Tensor.numpy()`` refuses tensors that require grad or live on a
		non-CPU device, so the tensor is detached and moved to CPU first.
		NumPy has no bfloat16 dtype, so bfloat16 is widened to float32; every
		other dtype is kept unchanged.
		"""
		tensor = tensor.detach().cpu()
		if tensor.dtype == torch.bfloat16:
			tensor = tensor.to(torch.float32)
		return tensor.numpy()

	def save(self, samples, filename_prefix="ComfyUI"):
		"""Write ``samples["samples"]`` to ``<prefix>_<counter:05>_.npy``.

		Returns:
			A 1-tuple holding the written file's name (without its folder).
		"""
		full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
		fname = f"{filename}_{counter:05}_.npy"
		path = os.path.join(full_output_folder, fname)
		np.save(path, self._to_numpy(samples["samples"]))
		return (fname,)

# Node type name -> node class (the ComfyUI custom-node registration dict).
# Every key equals the class's own __name__, so the dict is derived from the
# class list; insertion order follows the tuple below.
NODE_CLASS_MAPPINGS = {
	node_cls.__name__: node_cls
	for node_cls in (LoadLatentNumpy, LoadLatentUrl, SaveLatentNumpy)
}
